"""
polar : polar coordinate
angle: the orientation of the original image
"""

import argparse
from collections import defaultdict

import numpy as np
import torch
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange

from disvae.models.anneal import get_anneal
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from exps.visualize import plot_reconstruct, plt_sample_traversal, plot_projection
from utils.visualize import *
from exps.patterns import data_generator
from utils.helpers import set_seed

def get_preds(model, dl, device="cuda"):
    """Encode every batch of a dataloader and collect means and labels.

    Args:
        model: VAE-style module exposing ``encoder(img) -> (mu, logvar)``.
        dl: iterable of ``(img, label)`` batches; images are reshaped to
            ``(-1, 1, 64, 64)`` before encoding.
        device: device the images are moved to before encoding.

    Returns:
        Tuple ``(preds, targets)``: concatenated encoder means (on ``device``)
        and concatenated labels (left on their original device).
    """
    # Remember the caller's mode: the original code unconditionally switched
    # back to train mode, which clobbered models that were in eval mode.
    was_training = model.training
    model.eval()
    preds = []
    targets = []
    # One no-grad scope for the whole pass instead of re-entering per batch.
    with torch.no_grad():
        for img, label in dl:
            img = img.view(-1, 1, 64, 64).to(device)
            mu, _ = model.encoder(img)
            preds.append(mu)
            targets.append(label)
    preds, targets = torch.cat(preds), torch.cat(targets)
    if was_training:
        model.train()
    return preds, targets


def train(dl, model, loss_f, epochs, eval_iter=40, device="cuda", lr=5e-4):
    """Run the VAE training loop, logging per-epoch metrics to wandb.

    Args:
        dl: dataloader yielding ``(img, label)`` batches.
        model: VAE returning ``(recon_batch, latent_dist, latent_sample)``.
        loss_f: loss callable ``(img, recon, latent_dist, is_train, storer,
            latent_sample=...)`` that also records per-term values in storer.
        epochs: number of passes over ``dl``.
        eval_iter: every this many epochs, log reconstruction images and a
            latent-vs-factor scatter figure.
        device: device the image batches are moved to.
        lr: Adam learning rate. Previously hard-coded to 5e-4 even though the
            script config exposes ``learning_rate``; the default keeps the old
            behavior.

    Returns:
        The storer dict of the final epoch (mean losses, epoch index, etc.).
    """
    opt = optim.Adam(model.parameters(), lr)
    for e in trange(epochs):
        storer = defaultdict(list)
        loss_set = []
        for img, label in dl:
            img = img.to(device)
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_set.append(loss.item())
        # Collapse the per-batch lists the loss function accumulated into
        # per-epoch means so they log as scalars.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        if e % eval_iter == 0:
            # Log the last batch's reconstruction next to its input.
            wandb.log({
                'recon': wandb.Image(recon_batch[0, 0]),
                'img': wandb.Image(img[0, 0])
            })
            # Rank latent dimensions by KL contribution (keys named 'kl_*'),
            # most informative first.
            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' == k[:3]])
            _, index = kl_loss.sort(descending=True)
            mu, logvar = latent_dist

            # Scatter the top-3 latent means against the first two label
            # factors of the last batch to eyeball correlations.
            points = mu.data[:, index[:3]].cpu()
            colors = label
            fig, axes = plt.subplots(2, 3)
            for i in range(2):  # factor
                for j in range(3):  # latent variable
                    axes[i, j].scatter(points[:, j].numpy(), colors[:, i].numpy(), s=0.2)
                    axes[i, j].set_title(f'{i + 1},{index[j]}')
            storer['correlated'] = wandb.Image(fig)
            plt.close()
        storer['epoch'] = e
        wandb.log(storer, sync=False)
    return storer

if __name__ == "__main__":
    hyperparameter_defaults = dict(
        batch_size=256,
        learning_rate=0.0005,
        epochs=401,
        dimension=6,
        loss='betaH',
        beta=3,
        img_id=0,
        polar=False,
        angle=45,
        random_seed=114
    )
    parser = argparse.ArgumentParser()
    for key, value in hyperparameter_defaults.items():
        parser.add_argument(f'--{key}', default=value, type=type(value))
    args = parser.parse_args()
    config = args

    set_seed(config.random_seed)
    wandb.init(project="exps", config=config, group='translation')

    # generate data
    img_id = config.img_id
    epochs = config.epochs
    dim = config.dimension
    beta = config.beta

    img = torch.load(f'patterns/{img_id}.pat')
    img = data_generator.rotate(img, config.angle)

    dataset = None
    if config.polar:
        dataset = data_generator.PolarTranslation(img)
    else:
        dataset = data_generator.Translation(img)
    dl = DataLoader(dataset,
                    num_workers=0, batch_size=config.batch_size, shuffle=True)

    vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)
    vae.cuda()
    vae.train()

    # wandb.watch(vae, log_freq=10)
    iterations = len(dl) * epochs
    anneal = get_anneal('monotonic', iterations, 60, 1)

    loss_f = get_loss_s(config.loss, anneal, config.beta, len(dl.dataset))

    storer = train(dl, vae, loss_f, epochs)

    kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' == k[:3]])
    _, index = kl_loss.sort(descending=True)
    print(index)
    # torch.save(vae.state_dict(), 'tmp.pt')

    preds, targets = get_preds(vae, dl)
    data = torch.cat([preds.cpu(), targets], 1).cpu()

    vae.eval()
    vae.cpu()
    fig = plot_reconstruct(dataset.imgs, (3, 2), vae)
    wandb.log({'reconstruction': wandb.Image(fig)})

    fig = plt_sample_traversal(dataset.imgs[:1], vae, 7, index[:2], r=2)
    wandb.log({'traversal': wandb.Image(fig)})
    imgs = dataset.imgs.reshape(40, 40, 64, 64)
    fig = plot_projection(imgs[::4, ::4], vae, dim=index[:2])
    wandb.log({'projection': wandb.Image(fig)})

    table = wandb.Table(data=data.tolist(), columns=["0", "1", "2", "3", "4", "5", "c1", "c2"])
    wandb.log({'embedding': table})
    plt.show()